{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Azure OpenAI for big data\n",
    "\n",
    "The Azure OpenAI service can be used to solve a large number of natural language tasks through prompting the completion API. To make it easier to scale your prompting workflows from a few examples to large datasets of examples, we have integrated the Azure OpenAI service with the distributed machine learning library [SynapseML](https://www.microsoft.com/en-us/research/blog/synapseml-a-simple-multilingual-and-massively-parallel-machine-learning-library/). This integration makes it easy to use the [Apache Spark](https://spark.apache.org/) distributed computing framework to process millions of prompts with the OpenAI service. This tutorial shows how to apply large language models at a distributed scale using Azure OpenAI. "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "The key prerequisites for this quickstart are a working Azure OpenAI resource and an Apache Spark cluster with SynapseML installed. We suggest creating a Synapse workspace, but Azure Databricks, HDInsight, Spark on Kubernetes, or even a Python environment with the `pyspark` package will also work.\n",
    "\n",
    "1. An Azure OpenAI resource: [request access](https://customervoice.microsoft.com/Pages/ResponsePage.aspx?id=v4j5cvGGr0GRqy180BHbR7en2Ais5pxKtso_Pz4b1_xUOFA5Qk1UWDRBMjg0WFhPMkIzTzhKQ1dWNyQlQCN0PWcu) before [creating a resource](https://docs.microsoft.com/en-us/azure/cognitive-services/openai/how-to/create-resource?pivots=web-portal#create-a-resource)\n",
    "1. [Create a Synapse workspace](https://docs.microsoft.com/en-us/azure/synapse-analytics/get-started-create-workspace)\n",
    "1. [Create a serverless Apache Spark pool](https://docs.microsoft.com/en-us/azure/synapse-analytics/get-started-analyze-spark#create-a-serverless-apache-spark-pool)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import this guide as a notebook\n",
    "\n",
    "The next step is to add this code to your Spark cluster. You can either create a notebook on your Spark platform and copy the code into it to run the demo, or download this notebook and import it into your workspace.\n",
    "\n",
    "-\t[Download this demo as a notebook](https://github.com/microsoft/SynapseML/blob/master/docs/Explore%20Algorithms/OpenAI/OpenAI.ipynb) (select **Raw**, then save the file)\n",
    "-\tImport the notebook. \n",
    "    * If you are using Synapse Analytics, [import into the Synapse Workspace](https://docs.microsoft.com/en-us/azure/synapse-analytics/spark/apache-spark-development-using-notebooks#create-a-notebook).\n",
    "    * If you are using Databricks, [import into the Databricks Workspace](https://docs.microsoft.com/en-us/azure/databricks/notebooks/notebooks-manage#create-a-notebook).\n",
    "    * If you are using Fabric, [import into the Fabric Workspace](https://learn.microsoft.com/en-us/fabric/data-engineering/how-to-use-notebook).\n",
    "-   Install SynapseML on your cluster. See the installation instructions for Synapse at the bottom of [the SynapseML website](https://microsoft.github.io/SynapseML/).\n",
    "    * If you are using Fabric, see the [installation guide](https://learn.microsoft.com/en-us/fabric/data-science/install-synapseml). This requires pasting an extra cell at the top of the notebook you imported.\n",
    "-  \tConnect your notebook to a cluster and follow along, editing and running the cells."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fill in service information\n",
    "\n",
    "Next, edit the cell in the notebook to point to your service. In particular, set the `service_name`, `deployment_name`, `deployment_name_embeddings`, and `key` variables to match your Azure OpenAI service:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.core.platform import find_secret\n",
    "\n",
    "# Fill in the following lines with your service information\n",
    "# Learn more about choosing an embedding model: https://openai.com/blog/new-and-improved-embedding-model\n",
    "service_name = \"synapseml-openai-2\"\n",
    "deployment_name = \"gpt-4.1-mini\"\n",
    "deployment_name_embeddings = \"text-embedding-ada-002\"\n",
    "\n",
    "key = find_secret(\n",
    "    secret_name=\"openai-api-key-2\", keyvault=\"mmlspark-build-keys\"\n",
    ")  # please replace this line with your key as a string\n",
    "\n",
    "assert key is not None and service_name is not None"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a dataset of prompts\n",
    "\n",
    "Next, create a dataframe consisting of a series of rows, with one prompt per row. \n",
    "\n",
    "You can also load data directly from ADLS or other databases. For more information on loading and preparing Spark dataframes, see the [Apache Spark data loading guide](https://spark.apache.org/docs/latest/sql-data-sources.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumes a Spark session named spark already exists, as it does in Fabric, Synapse, and Databricks notebooks\n",
    "df = spark.createDataFrame(\n",
    "    [\n",
    "        (\"Hello my name is\",),\n",
    "        (\"The best code is code thats\",),\n",
    "        (\"SynapseML is \",),\n",
    "    ]\n",
    ").toDF(\"prompt\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More Usage Examples"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating Text Embeddings\n",
    "\n",
    "In addition to completing text, we can also embed text for use in downstream algorithms or vector retrieval architectures. Creating embeddings allows you to search and retrieve documents from large collections and can be used when prompt engineering isn't sufficient for the task."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more information on using `OpenAIEmbedding` see our [embedding guide](./Quickstart%20-%20OpenAI%20Embedding)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIEmbedding\n",
    "\n",
    "embedding = (\n",
    "    OpenAIEmbedding()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name_embeddings)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setTextCol(\"prompt\")\n",
    "    .setErrorCol(\"error\")\n",
    "    .setOutputCol(\"embeddings\")\n",
    ")\n",
    "\n",
    "display(embedding.transform(df).show(truncate=False))"
   ]
  },
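  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Embeddings are typically used for retrieval by ranking documents by their similarity to a query embedding. The following is a minimal pure-Python sketch of cosine-similarity ranking. It assumes each embedding has already been converted to a plain list of floats (for example, values collected from the `embeddings` column). The document names and 3-dimensional vectors are made-up toy values, not model output.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Dot product divided by the product of the vector norms\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    if norm_a == 0 or norm_b == 0:\n",
    "        return 0.0\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "\n",
    "def rank_documents(query_vec, doc_vecs, top_k=2):\n",
    "    # doc_vecs maps a document id to its embedding; returns the top_k (id, score) pairs\n",
    "    scored = [\n",
    "        (doc_id, cosine_similarity(query_vec, vec)) for doc_id, vec in doc_vecs.items()\n",
    "    ]\n",
    "    scored.sort(key=lambda pair: pair[1], reverse=True)\n",
    "    return scored[:top_k]\n",
    "\n",
    "\n",
    "# Toy 3-dimensional vectors; real embeddings have many more dimensions\n",
    "docs = {\n",
    "    'doc_a': [1.0, 0.0, 0.0],\n",
    "    'doc_b': [0.7, 0.7, 0.0],\n",
    "    'doc_c': [0.0, 0.0, 1.0],\n",
    "}\n",
    "print(rank_documents([1.0, 0.1, 0.0], docs))\n",
    "```\n",
    "\n",
    "At scale you would typically store the vectors in a vector index instead of scoring every document in Python. This sketch only illustrates the scoring rule."
   ]
  },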
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chat Completion\n",
    "\n",
    "Chat models such as ChatGPT and GPT-4 take a conversation (a list of messages, each with a role) instead of a single prompt. The `OpenAIChatCompletion` transformer exposes this functionality at scale."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIChatCompletion\n",
    "from pyspark.sql import Row\n",
    "from pyspark.sql.types import *\n",
    "\n",
    "\n",
    "def make_message(role, content):\n",
    "    return Row(role=role, content=content, name=role)\n",
    "\n",
    "\n",
    "chat_df = spark.createDataFrame(\n",
    "    [\n",
    "        (\n",
    "            [\n",
    "                make_message(\n",
    "                    \"system\", \"You are an AI chatbot with red as your favorite color\"\n",
    "                ),\n",
    "                make_message(\"user\", \"Whats your favorite color\"),\n",
    "            ],\n",
    "        ),\n",
    "        (\n",
    "            [\n",
    "                make_message(\"system\", \"You are very excited\"),\n",
    "                make_message(\"user\", \"How are you today\"),\n",
    "            ],\n",
    "        ),\n",
    "    ]\n",
    ").toDF(\"messages\")\n",
    "\n",
    "\n",
    "chat_completion = (\n",
    "    OpenAIChatCompletion()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setMessagesCol(\"messages\")\n",
    "    .setOutputCol(\"chat_completions\")\n",
    "    .setErrorCol(\"chat_completions_error\")\n",
    ")\n",
    "\n",
    "display(\n",
    "    chat_completion.transform(chat_df)\n",
    "    .select(\"messages\", \"chat_completions.choices.message.content\")\n",
    "    .show(truncate=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chat Completion - Advanced Parameters for Reproducible Outputs\n",
    "\n",
    "SynapseML supports additional parameters that give you finer control over model sampling, which helps make outputs reproducible:\n",
    "\n",
    "- **`temperature`**: Controls randomness; lower values make the output more deterministic. OpenAI models accept a float in [0, 2]. Set it to 0 for the best reproducibility.\n",
    "- **`top_p`**: Controls nucleus sampling, an alternative to temperature. OpenAI models accept a float in [0, 1]. Set it close to 0 for the best reproducibility.\n",
    "- **`seed`**: Requests deterministic sampling. Set it to any constant int value; repeated requests with the same seed and parameters should usually return the same result, although determinism is best-effort rather than guaranteed.\n",
    "\n",
    "\n",
    "These parameters can be set globally using `OpenAIDefaults` or on individual transformer instances."
   ]
  },
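  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ranges listed above can be checked on the client before you submit a large job. The sketch below uses a hypothetical helper, `check_sampling_params`, which is not part of SynapseML. It validates `temperature`, `top_p`, and `seed` against the stated ranges and returns the accepted values as a dict.\n",
    "\n",
    "```python\n",
    "def check_sampling_params(temperature=None, top_p=None, seed=None):\n",
    "    # Hypothetical helper: validate the reproducibility parameters described above\n",
    "    params = {}\n",
    "    if temperature is not None:\n",
    "        if not 0 <= temperature <= 2:\n",
    "            raise ValueError('temperature must be in [0, 2]')\n",
    "        params['temperature'] = float(temperature)\n",
    "    if top_p is not None:\n",
    "        if not 0 <= top_p <= 1:\n",
    "            raise ValueError('top_p must be in [0, 1]')\n",
    "        params['top_p'] = float(top_p)\n",
    "    if seed is not None:\n",
    "        # bool is a subclass of int in Python, so reject it explicitly\n",
    "        if not isinstance(seed, int) or isinstance(seed, bool):\n",
    "            raise TypeError('seed must be an int')\n",
    "        params['seed'] = seed\n",
    "    return params\n",
    "\n",
    "\n",
    "print(check_sampling_params(temperature=0, top_p=0.1, seed=42))\n",
    "```\n",
    "\n",
    "These are the same values that the next cell passes to `OpenAIDefaults`."
   ]
  },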
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIChatCompletion\n",
    "from synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults\n",
    "\n",
    "# Set global defaults, including the reproducibility parameters\n",
    "defaults = OpenAIDefaults()\n",
    "defaults.set_deployment_name(deployment_name)\n",
    "defaults.set_subscription_key(key)\n",
    "defaults.set_URL(f\"https://{service_name}.openai.azure.com/\")\n",
    "defaults.set_temperature(0)\n",
    "defaults.set_top_p(0.1)\n",
    "defaults.set_seed(42)\n",
    "\n",
    "chat_completion = (\n",
    "    OpenAIChatCompletion()\n",
    "    .setMessagesCol(\"messages\")\n",
    "    .setOutputCol(\"chat_completions\")\n",
    "    .setErrorCol(\"chat_completions_error\")\n",
    ")\n",
    "\n",
    "display(\n",
    "    chat_completion.transform(chat_df)\n",
    "    .select(\"messages\", \"chat_completions.choices.message.content\")\n",
    "    .show(truncate=False)\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Legacy) Create the OpenAICompletion Apache Spark Client\n",
    "\n",
    "To apply the OpenAI Completion service to the dataframe you created, create an `OpenAICompletion` object, which serves as a distributed client. Service parameters can be set either to a single value or to a column of the dataframe, using the appropriate setters on the `OpenAICompletion` object. Here we set `maxTokens` to 200, which caps the number of tokens generated for each completion; the prompt plus the completion must still fit within the model's context length. A token is roughly four characters of English text. We also set the `promptCol` parameter to the name of the prompt column in the dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAICompletion\n",
    "\n",
    "completion = (\n",
    "    OpenAICompletion()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setMaxTokens(200)\n",
    "    .setPromptCol(\"prompt\")\n",
    "    .setErrorCol(\"error\")\n",
    "    .setOutputCol(\"completions\")\n",
    ")"
   ]
  },
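  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reason about `maxTokens`, it helps to estimate how long a prompt is. The sketch below applies the rough rule of four characters per token mentioned above. This is a heuristic, not the model's real tokenizer. The `context_length` value is a placeholder that you should replace with your model's actual limit.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def estimate_tokens(text, chars_per_token=4):\n",
    "    # Rough heuristic: about four characters per token for English text\n",
    "    return math.ceil(len(text) / chars_per_token)\n",
    "\n",
    "\n",
    "def fits_in_context(prompt, max_tokens, context_length):\n",
    "    # The prompt plus the requested completion must fit in the model's context window\n",
    "    return estimate_tokens(prompt) + max_tokens <= context_length\n",
    "\n",
    "\n",
    "prompt = 'The best code is code thats'\n",
    "print(estimate_tokens(prompt))  # 27 characters -> 7 tokens (rounded up)\n",
    "print(fits_in_context(prompt, max_tokens=200, context_length=4096))\n",
    "```"
   ]
  },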
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Legacy) Transform the dataframe with the OpenAICompletion Client\n",
    "\n",
    "After creating the dataframe and the completion client, you can transform your input dataset and add a column called `completions` with all of the information the service adds. Select just the text for simplicity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.functions import col\n",
    "\n",
    "completed_df = completion.transform(df).cache()\n",
    "display(\n",
    "    completed_df.select(\n",
    "        col(\"prompt\"),\n",
    "        col(\"error\"),\n",
    "        col(\"completions.choices.text\").getItem(0).alias(\"text\"),\n",
    "    ).show(truncate=False)\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Your output should look something like this. The completion text will be different from the sample.\n",
    "\n",
    "| **prompt** | **error** | **text** |\n",
    "|:---:|:---:|:---:|\n",
    "| Hello my name is | null | Makaveli I'm eighteen years old and I want to be a rapper when I grow up I love writing and making music I'm from Los Angeles, CA |\n",
    "| The best code is code thats | null | understandable This is a subjective statement, and there is no definitive answer. |\n",
    "| SynapseML is | null | A machine learning algorithm that is able to learn how to predict the future outcome of events. |"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Improve throughput with request batching for OpenAICompletion\n",
    "\n",
    "The previous example sends one request to the service for each prompt. To complete multiple prompts in a single request, use batch mode: create a dataframe with a list of prompts in each row, and in the `OpenAICompletion` object set the batch prompt column (`setBatchPromptCol`) instead of the prompt column.\n",
    "\n",
    "At the time of writing, a single request was limited to 20 prompts and to 2048 tokens (approximately 1500 words). Current limits depend on the model and deployment, so check the Azure OpenAI documentation for your model."
   ]
  },
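  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you start from a flat list of prompts, you need to group them into lists that respect the per-request prompt limit before you build the `batchPrompt` column. The sketch below is a plain-Python chunker. The default of 20 is the figure quoted above, so check your deployment's current limits before you rely on it.\n",
    "\n",
    "```python\n",
    "def chunk_prompts(prompts, max_per_request=20):\n",
    "    # Split a flat list of prompts into consecutive lists of at most max_per_request items\n",
    "    if max_per_request < 1:\n",
    "        raise ValueError('max_per_request must be at least 1')\n",
    "    return [\n",
    "        prompts[i:i + max_per_request] for i in range(0, len(prompts), max_per_request)\n",
    "    ]\n",
    "\n",
    "\n",
    "prompts = ['prompt %d' % i for i in range(45)]\n",
    "batches = chunk_prompts(prompts)\n",
    "print([len(b) for b in batches])  # [20, 20, 5]\n",
    "\n",
    "# Each batch becomes one row of an Array[String] column, for example:\n",
    "# batch_df = spark.createDataFrame([(b,) for b in batches]).toDF('batchPrompt')\n",
    "```"
   ]
  },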
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_df = spark.createDataFrame(\n",
    "    [\n",
    "        ([\"The time has come\", \"Pleased to\", \"Today stocks\", \"Here's to\"],),\n",
    "        ([\"The only thing\", \"Ask not what\", \"Every litter\", \"I am\"],),\n",
    "    ]\n",
    ").toDF(\"batchPrompt\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, create the `OpenAICompletion` object. Because the `batchPrompt` column is of type `Array[String]`, set it with `setBatchPromptCol` instead of setting the prompt column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_completion = (\n",
    "    OpenAICompletion()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setMaxTokens(200)\n",
    "    .setBatchPromptCol(\"batchPrompt\")\n",
    "    .setErrorCol(\"error\")\n",
    "    .setOutputCol(\"completions\")\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The call to `transform` makes one request per row. Because each row holds multiple prompts, each request carries all of the prompts in that row. The output contains one row for each input row, and its `completions` column holds the choices for all of that row's prompts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "completed_batch_df = batch_completion.transform(batch_df).cache()\n",
    "display(completed_batch_df.show(truncate=False))"
   ]
  },
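  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the legacy Completions API and the default of one completion per prompt, each choice carries an `index` that matches the position of its prompt in the batch. The sketch below pairs prompts with completion text in plain Python. It assumes the rows have been collected to the driver as dictionaries, and the sample `choices` data is made up.\n",
    "\n",
    "```python\n",
    "def pair_prompts_with_text(prompts, choices):\n",
    "    # Sort by the choice index so that choice i lines up with prompt i (assumes one choice per prompt)\n",
    "    ordered = sorted(choices, key=lambda c: c['index'])\n",
    "    if len(ordered) != len(prompts):\n",
    "        raise ValueError('expected one choice per prompt')\n",
    "    return [(p, c['text']) for p, c in zip(prompts, ordered)]\n",
    "\n",
    "\n",
    "# Made-up example: the choices are deliberately out of order\n",
    "prompts = ['The time has come', 'Pleased to']\n",
    "choices = [\n",
    "    {'index': 1, 'text': ' meet you.'},\n",
    "    {'index': 0, 'text': ' to act.'},\n",
    "]\n",
    "for prompt, text in pair_prompts_with_text(prompts, choices):\n",
    "    print(prompt + text)\n",
    "```"
   ]
  },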
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using an automatic minibatcher\n",
    "\n",
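    "If your data has one prompt per row, you can group consecutive rows into fixed-size batches with SynapseML's `FixedMiniBatchTransformer`. It turns each column into an array column, which you can then pass to the batch completion client."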
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.stages import FixedMiniBatchTransformer\n",
    "from synapse.ml.core.spark import FluentAPI  # adds the mlTransform method to DataFrames\n",
    "\n",
    "completed_autobatch_df = (\n",
    "    # Force a single partition so that this small 3-row dataframe forms a single batch;\n",
    "    # you can remove this step for large datasets\n",
    "    df.coalesce(1)\n",
    "    .mlTransform(FixedMiniBatchTransformer(batchSize=4))\n",
    "    .withColumnRenamed(\"prompt\", \"batchPrompt\")\n",
    "    .mlTransform(batch_completion)\n",
    ")\n",
    "\n",
    "display(completed_autobatch_df.show(truncate=False))"
   ]
  },
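  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a fixed-size minibatcher groups consecutive rows within a partition and turns every column into an array column. That is why the `prompt` column can then be renamed to `batchPrompt`. The plain-Python sketch below imitates this reshaping on a list of row dictionaries. It illustrates the idea and is not SynapseML's implementation.\n",
    "\n",
    "```python\n",
    "def fixed_minibatch(rows, batch_size):\n",
    "    # Group consecutive rows and turn each column into a list of that column's values\n",
    "    batches = []\n",
    "    for start in range(0, len(rows), batch_size):\n",
    "        chunk = rows[start:start + batch_size]\n",
    "        batches.append({col: [row[col] for row in chunk] for col in chunk[0]})\n",
    "    return batches\n",
    "\n",
    "\n",
    "rows = [\n",
    "    {'prompt': 'Hello my name is'},\n",
    "    {'prompt': 'The best code is code thats'},\n",
    "    {'prompt': 'SynapseML is '},\n",
    "]\n",
    "# With batch_size=4 and only three rows, a single batch holds all three prompts\n",
    "print(fixed_minibatch(rows, batch_size=4))\n",
    "```"
   ]
  },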
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prompt engineering for translation\n",
    "\n",
    "The Azure OpenAI service can solve many different natural language tasks through [prompt engineering](https://docs.microsoft.com/en-us/azure/cognitive-services/openai/how-to/completions). Here, we show an example of prompting for language translation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "translate_df = spark.createDataFrame(\n",
    "    [\n",
    "        (\"Japanese: Ookina hako English: Big box Japanese: Midori tako English:\",),\n",
    "        (\n",
    "            \"French: Quelle heure est-il à Montréal? English: What time is it in Montreal? French: Où est le poulet? English:\",\n",
    "        ),\n",
    "    ]\n",
    ").toDF(\"prompt\")\n",
    "\n",
    "display(completion.transform(translate_df).show(truncate=False))"
   ]
  },
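  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Few-shot prompts like the ones above follow a fixed pattern: a handful of solved examples followed by an unsolved one, ending with the label the model should complete. A small helper can build these strings consistently for every row. `build_few_shot_prompt` is a hypothetical name, and it separates examples with single spaces to match the prompts above.\n",
    "\n",
    "```python\n",
    "def build_few_shot_prompt(source_lang, target_lang, examples, query):\n",
    "    # examples is a list of (source_text, target_text) pairs that demonstrate the task\n",
    "    parts = []\n",
    "    for source_text, target_text in examples:\n",
    "        parts.append('%s: %s %s: %s' % (source_lang, source_text, target_lang, target_text))\n",
    "    # End with the unsolved query and the target label for the model to complete\n",
    "    parts.append('%s: %s %s:' % (source_lang, query, target_lang))\n",
    "    return ' '.join(parts)\n",
    "\n",
    "\n",
    "prompt = build_few_shot_prompt('Japanese', 'English', [('Ookina hako', 'Big box')], 'Midori tako')\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "You could build such strings in Python before calling `spark.createDataFrame`, or derive them from existing columns with Spark functions."
   ]
  },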
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prompt for question answering\n",
    "\n",
    "Here, we prompt the model for general-knowledge question answering:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qa_df = spark.createDataFrame(\n",
    "    [\n",
    "        (\n",
    "            \"Q: Where is the Grand Canyon? A: The Grand Canyon is in Arizona. Q: What is the weight of the Burj Khalifa in kilograms? A:\",\n",
    "        )\n",
    "    ]\n",
    ").toDF(\"prompt\")\n",
    "\n",
    "display(completion.transform(qa_df).show(truncate=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Structured JSON output with json_schema\n",
    "\n",
    "Set `responseFormat` to `json_schema` to constrain the model to return JSON that strictly matches a provided JSON Schema. This improves reliability for downstream parsing and analytics.\n",
    "\n",
    "**Key points:**\n",
    "- Set `responseFormat` to a dict containing `type=\"json_schema\"` and a nested `json_schema` object (JSON string form not supported).\n",
    "- Provide the schema as a Python dict / Scala Map plus a `name` and optional `strict` flag.\n",
    "- Bare `json_schema` string is rejected; supply the full dict form.\n",
    "- Allowed simple String values for `responseFormat`: `text`, `json_object`.\n",
    "- For `json_schema` minimal validation only checks the presence of the `json_schema` key; the nested schema is passed through unchanged.\n",
    "- The model output will be constrained to the schema (no extra properties when `additionalProperties: false`).\n",
    "\n",
    "**Summary table**\n",
    "\n",
    "| Type        | How to set             | Requires nested schema | Notes |\n",
    "|-------------|------------------------|------------------------|-------|\n",
    "| text        | String(\"text\")         | No                     | Raw string output |\n",
    "| json_object | String(\"json_object\")  | No                     | Model attempts well‑formed JSON (not strictly validated) |\n",
    "| json_schema | Dict/Map only          | Yes                    | Strict; reject bare string or JSON string form |\n",
    "\n",
    "Below we request a single field `answer` as structured JSON."
   ]
  },
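  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rules listed above can be summarized as a small validation function. The sketch below is hypothetical. It mirrors the rules only as this section describes them for the chat completions form and is not SynapseML's own validation code, whose exact behavior may differ. It does not handle the flattened Responses API form.\n",
    "\n",
    "```python\n",
    "def check_chat_response_format(response_format):\n",
    "    # Hypothetical check that mirrors the rules in the list above\n",
    "    if isinstance(response_format, str):\n",
    "        if response_format in ('text', 'json_object'):\n",
    "            return response_format\n",
    "        # A bare 'json_schema' string (or any other string) is rejected\n",
    "        raise ValueError('unsupported string responseFormat: %r' % response_format)\n",
    "    if isinstance(response_format, dict):\n",
    "        if response_format.get('type') == 'json_schema' and 'json_schema' not in response_format:\n",
    "            raise ValueError('json_schema requires a nested json_schema object')\n",
    "        # The nested schema itself is passed through unchanged\n",
    "        return response_format\n",
    "    raise TypeError('responseFormat must be a str or a dict')\n",
    "\n",
    "\n",
    "print(check_chat_response_format('json_object'))\n",
    "try:\n",
    "    check_chat_response_format('json_schema')\n",
    "except ValueError as err:\n",
    "    print('Rejected:', err)\n",
    "```"
   ]
  },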
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIChatCompletion\n",
    "from pyspark.sql import Row\n",
    "\n",
    "# Define the JSON Schema we want the model to satisfy\n",
    "schema = {\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\"answer\": {\"type\": \"string\"}},\n",
    "    \"required\": [\"answer\"],\n",
    "    \"additionalProperties\": False,\n",
    "}\n",
    "\n",
    "# Single user message requesting structured output\n",
    "messages_df = spark.createDataFrame(\n",
    "    [\n",
    "        (\n",
    "            [\n",
    "                Row(\n",
    "                    role=\"user\",\n",
    "                    content=\"What is the capital of France?\",\n",
    "                )\n",
    "            ],\n",
    "        )\n",
    "    ]\n",
    ").toDF(\"messages\")\n",
    "\n",
    "chat_structured = (\n",
    "    OpenAIChatCompletion()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setMessagesCol(\"messages\")\n",
    "    .setResponseFormat(\n",
    "        {\n",
    "            \"type\": \"json_schema\",\n",
    "            \"json_schema\": {\n",
    "                \"name\": \"answer_schema\",\n",
    "                \"strict\": True,\n",
    "                \"schema\": schema,\n",
    "            },\n",
    "        }\n",
    "    )\n",
    "    .setOutputCol(\"chat_structured\")\n",
    "    .setErrorCol(\"chat_structured_error\")\n",
    ")\n",
    "\n",
    "display(\n",
    "    chat_structured.transform(messages_df)\n",
    "    .select(\"chat_structured.choices.message.content\")\n",
    "    .show(truncate=False)\n",
    ")\n",
    "# The returned content should be a JSON string such as {\"answer\": \"Paris\"}"
   ]
  },
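  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `content` column holds the model's JSON as a string, so downstream code still has to parse it. The sketch below parses one such string with the standard `json` module. It then checks the result against the `required` and `additionalProperties` settings of the schema used above. This is a lightweight hand-written check, not a full JSON Schema validator; a library such as `jsonschema` covers the whole specification. The sample value is built with `json.dumps` and stands in for a real model response.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "schema = {\n",
    "    'type': 'object',\n",
    "    'properties': {'answer': {'type': 'string'}},\n",
    "    'required': ['answer'],\n",
    "    'additionalProperties': False,\n",
    "}\n",
    "\n",
    "\n",
    "def parse_structured(content, schema):\n",
    "    # Parse the JSON string and apply a few basic checks taken from the schema\n",
    "    data = json.loads(content)\n",
    "    if not isinstance(data, dict):\n",
    "        raise ValueError('expected a JSON object')\n",
    "    missing = [key for key in schema.get('required', []) if key not in data]\n",
    "    if missing:\n",
    "        raise ValueError('missing required keys: %s' % missing)\n",
    "    if schema.get('additionalProperties') is False:\n",
    "        extra = [key for key in data if key not in schema.get('properties', {})]\n",
    "        if extra:\n",
    "            raise ValueError('unexpected keys: %s' % extra)\n",
    "    return data\n",
    "\n",
    "\n",
    "sample = json.dumps({'answer': 'Paris'})  # stands in for a value from the content column\n",
    "print(parse_structured(sample, schema)['answer'])\n",
    "```"
   ]
  },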
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ResponseFormat Options Quick Tests\n",
    "\n",
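    "The following cells exercise the `text` and `json_object` formats with the chat completions API, the flattened `json_schema` form with the Responses API, and the rejection of a bare `json_schema` string."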
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIChatCompletion\n",
    "from pyspark.sql import Row\n",
    "\n",
    "messages_df = spark.createDataFrame(\n",
    "    [([Row(role=\"user\", content=\"Say hello\")],)], [\"messages\"]\n",
    ")\n",
    "chat_text = (\n",
    "    OpenAIChatCompletion()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setMessagesCol(\"messages\")\n",
    "    .setResponseFormat(\"text\")\n",
    "    .setOutputCol(\"chat_text\")\n",
    ")\n",
    "display(\n",
    "    chat_text.transform(messages_df)\n",
    "    .select(\"chat_text.choices.message.content\")\n",
    "    .show(truncate=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages_df = spark.createDataFrame(\n",
    "    [([Row(role=\"user\", content=\"Return a JSON object with key greeting\")],)],\n",
    "    [\"messages\"],\n",
    ")\n",
    "chat_json_obj = (\n",
    "    OpenAIChatCompletion()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setMessagesCol(\"messages\")\n",
    "    .setResponseFormat(\"json_object\")\n",
    "    .setOutputCol(\"chat_json_obj\")\n",
    ")\n",
    "display(\n",
    "    chat_json_obj.transform(messages_df)\n",
    "    .select(\"chat_json_obj.choices.message.content\")\n",
    "    .show(truncate=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Responses API Structured JSON with json_schema (flattened form)\n",
    "\n",
    "For the Responses API, you can pass a flattened `json_schema` dict: top-level `name`, `strict`, and `schema` keys."
   ]
  },
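  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compared with the chat completions example above, the flattened form lifts the contents of the nested `json_schema` object (`name`, `strict`, `schema`) to the top level, next to `type`. The hypothetical helper below converts one form into the other. It can be handy if you reuse the same schema with both APIs.\n",
    "\n",
    "```python\n",
    "def flatten_json_schema_format(nested):\n",
    "    # Convert {'type': 'json_schema', 'json_schema': {...}} into the flattened form\n",
    "    if nested.get('type') != 'json_schema' or 'json_schema' not in nested:\n",
    "        raise ValueError('expected a nested json_schema response format')\n",
    "    flat = {'type': 'json_schema'}\n",
    "    flat.update(nested['json_schema'])  # copies name, strict, and schema to the top level\n",
    "    return flat\n",
    "\n",
    "\n",
    "schema = {\n",
    "    'type': 'object',\n",
    "    'properties': {'answer': {'type': 'string'}},\n",
    "    'required': ['answer'],\n",
    "    'additionalProperties': False,\n",
    "}\n",
    "nested = {\n",
    "    'type': 'json_schema',\n",
    "    'json_schema': {'name': 'answer_schema', 'strict': True, 'schema': schema},\n",
    "}\n",
    "print(flatten_json_schema_format(nested))\n",
    "```"
   ]
  },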
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIPrompt\n",
    "\n",
    "schema = {\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\"answer\": {\"type\": \"string\"}},\n",
    "    \"required\": [\"answer\"],\n",
    "    \"additionalProperties\": False,\n",
    "}\n",
    "df = spark.createDataFrame(\n",
    "    [(\"France\", \"capital\"), (\"Germany\", \"capital\")], [\"text\", \"category\"]\n",
    ")\n",
    "prompt_flat = (\n",
    "    OpenAIPrompt()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setApiType(\"responses\")\n",
    "    .setApiVersion(\"2025-04-01-preview\")\n",
    "    .setPromptTemplate(\"What is the {category} of {text}.\")\n",
    "    .setResponseFormat(\n",
    "        {\n",
    "            \"type\": \"json_schema\",\n",
    "            \"name\": \"answer_schema\",\n",
    "            \"strict\": True,\n",
    "            \"schema\": schema,\n",
    "        }\n",
    "    )\n",
    "    .setOutputCol(\"out_flat\")\n",
    ")\n",
    "display(prompt_flat.transform(df).select(\"out_flat\").show(truncate=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Invalid: Bare json_schema String\n",
    "\n",
    "Attempting to set `responseFormat` to the bare string `json_schema` should raise an error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    OpenAIPrompt().setResponseFormat(\"json_schema\")\n",
    "except Exception as e:\n",
    "    print(\"Expected error:\", e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with Usage Statistics\n",
    "\n",
    "The following examples show how to enable `returnUsage` for different OpenAI API surfaces and how to inspect the usage map (token counts) returned alongside the model output.\n"
   ]
  },
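  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Usage maps are most useful in aggregate, for example to estimate the token consumption of a whole job. The sketch below sums usage dictionaries that have been collected to the driver. It assumes the key names used by the OpenAI APIs: `prompt_tokens` and `completion_tokens` for chat completions and embeddings, and `input_tokens` and `output_tokens` for the Responses API. It treats each pair as equivalent. Check the actual keys in your `usage` column before relying on it. The usage values are made up.\n",
    "\n",
    "```python\n",
    "def total_usage(usage_maps):\n",
    "    # Normalize chat-style and Responses-style keys, then sum across rows\n",
    "    totals = {'input_tokens': 0, 'output_tokens': 0, 'total_tokens': 0}\n",
    "    for usage in usage_maps:\n",
    "        if usage is None:\n",
    "            continue  # skip rows without a usage map\n",
    "        input_tokens = usage.get('input_tokens', usage.get('prompt_tokens', 0))\n",
    "        output_tokens = usage.get('output_tokens', usage.get('completion_tokens', 0))\n",
    "        totals['input_tokens'] += input_tokens\n",
    "        totals['output_tokens'] += output_tokens\n",
    "        totals['total_tokens'] += usage.get('total_tokens', input_tokens + output_tokens)\n",
    "    return totals\n",
    "\n",
    "\n",
    "# Made-up usage values for illustration\n",
    "rows = [\n",
    "    {'prompt_tokens': 12, 'completion_tokens': 30, 'total_tokens': 42},\n",
    "    {'input_tokens': 15, 'output_tokens': 25, 'total_tokens': 40},\n",
    "    {'prompt_tokens': 9, 'total_tokens': 9},  # embeddings report no completion tokens\n",
    "]\n",
    "print(total_usage(rows))\n",
    "```"
   ]
  },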
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Chat Completions with `returnUsage`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIPrompt\n",
    "\n",
    "chat_usage_prompt = (\n",
    "    OpenAIPrompt()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setApiType(\"chat_completions\")\n",
    "    .setReturnUsage(True)\n",
    "    .setPromptTemplate(\"Provide a fun fact about {topic}.\")\n",
    "    .setOutputCol(\"chat_with_usage\")\n",
    ")\n",
    "\n",
    "chat_usage_df = spark.createDataFrame([(\"Azure AI\",)], [\"topic\"])\n",
    "chat_usage_result = chat_usage_prompt.transform(chat_usage_df)\n",
    "display(\n",
    "    chat_usage_result.select(\"chat_with_usage.response\", \"chat_with_usage.usage\").show(\n",
    "        truncate=False\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Responses API with `returnUsage`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "responses_usage_prompt = (\n",
    "    OpenAIPrompt()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setDeploymentName(deployment_name)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setApiType(\"responses\")\n",
    "    .setApiVersion(\"2025-04-01-preview\")\n",
    "    .setReturnUsage(True)\n",
    "    .setPromptTemplate(\"List two key capabilities of {topic}.\")\n",
    "    .setOutputCol(\"responses_with_usage\")\n",
    ")\n",
    "\n",
    "responses_usage_df = spark.createDataFrame([(\"Azure OpenAI\",)], [\"topic\"])\n",
    "responses_usage_result = responses_usage_prompt.transform(responses_usage_df)\n",
    "display(\n",
    "    responses_usage_result.select(\n",
    "        \"responses_with_usage.response\", \"responses_with_usage.usage\"\n",
    "    ).show(truncate=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embeddings with `returnUsage`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from synapse.ml.services.openai import OpenAIEmbedding\n",
    "\n",
    "embedding_usage = (\n",
    "    OpenAIEmbedding()\n",
    "    .setSubscriptionKey(key)\n",
    "    .setCustomServiceName(service_name)\n",
    "    .setDeploymentName(deployment_name_embeddings)\n",
    "    .setReturnUsage(True)\n",
    "    .setTextCol(\"text\")\n",
    "    .setOutputCol(\"embedding_with_usage\")\n",
    ")\n",
    "\n",
    "embedding_usage_df = spark.createDataFrame(\n",
    "    [(\"Usage statistics help monitor token consumption.\",)], [\"text\"]\n",
    ")\n",
    "embedding_usage_result = embedding_usage.transform(embedding_usage_df)\n",
    "display(\n",
    "    embedding_usage_result.select(\n",
    "        \"embedding_with_usage.response\", \"embedding_with_usage.usage\"\n",
    "    ).show(truncate=False)\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 2
   },
   "notebookName": "AI Services - OpenAI",
   "notebookOrigID": 2846029038781816,
   "widgets": {}
  },
  "kernel_info": {
   "name": "synapse_pyspark"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  },
  "save_output": true,
  "synapse_widget": {
   "state": {
    "4bd0e60b-98ae-4bfe-98ee-6f0399ceb456": {
     "persist_state": {
      "view": {
       "chartOptions": {
        "aggregationType": "count",
        "categoryFieldKeys": [
         "0"
        ],
        "chartType": "bar",
        "isStacked": false,
        "seriesFieldKeys": [
         "0"
        ]
       },
       "tableOptions": {},
       "type": "details"
      }
     },
     "sync_state": {
      "isSummary": false,
      "language": "scala",
      "table": {
       "rows": [
        {
         "0": "Once upon a time",
         "1": [
          " there was a girl who had a dream of becoming a writer.\n\nShe started writing short stories"
         ]
        },
        {
         "0": "Hello my name is",
         "1": [
          "***** and I have a question about my cat\n\nHello, thank you for bringing your question to"
         ]
        },
        {
         "0": "The best code is code thats",
         "1": [
          " not there\n\nCommenting your code is important. Not only does it help you remember what you"
         ]
        }
       ],
       "schema": [
        {
         "key": "0",
         "name": "prompt",
         "type": "string"
        },
        {
         "key": "1",
         "name": "text",
         "type": "ArrayType(StringType,true)"
        }
       ],
       "truncated": false
      }
     },
     "type": "Synapse.DataFrame"
    }
   },
   "version": "0.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
